feat: add API specification for returning the k largest elements
#722
|
@ogrisel Do you have opinions on whether having three separate APIs is preferable to having just I tried tracking down actual usages of top There may also be other combinations. The PR currently specifies
We could, e.g., only specify
or strictly complementary
If you have a feel for what is preferable, that would be great to hear!
|
Thank you very much for the survey of current implementations and the API proposal. From a potential API consumer's point of view, the main proposal seems good to me. About the name:
I believe so. As a user I would accept the call to fail with a standard exception type such as
I think it's fine to allow topk to be faster by not enforcing this constraint. It would always be possible to add
I wouldn't like this requirement to hurt speed of adoption by backing libraries. I don't think many users need that in practice.
I don't think users have a need for this: they can flatten the input themselves if needed. But I have the impression that NumPy (and therefore Array API) often However for this particular case, the default in
That would remove a lot of value by preventing efficient parallelization by the underlying backend when running topk on a 2D array with
The benefit of the 3-function API would be to always be able to optimize memory usage by not allocating unnecessary arrays when k is large enough for this to matter. However, I have no good feeling about how much this would really be a problem in practice (and how much the underlying implementation would be able to skip the extra contiguous memory allocation internally).

Something that seems missing from this spec is the handling of NaN values. Maybe an extra kwarg is needed to specify whether they should be considered as either smallest or largest, or whether they need to be filtered out from the result (but then the result size would be data-dependent, with potentially empty arrays, which might also cause problems). Also, I assume that NaN values are always smaller than +inf and larger than -inf, but maybe not all libraries agree on that.
|
Revisiting this topic in preparation of helping it move forward. Quick first comment on:
I am not sure if it changed in the meantime or you just misread the JAX docs at the time, but https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.top_k.html says that both I had a peek at the jax.lax implementation as well, but couldn't tell so quickly because I'm not familiar with the So it looks like PyTorch, JAX and TF all return
|
JAX's current The rendered documentation is misleading due to a misplaced colon in the Returns block, but JAX returns

    In [2]: x = jax.numpy.arange(100, 110)
    In [3]: jax.lax.top_k(x, 3)
    Out[3]: [Array([109, 108, 107], dtype=int32), Array([9, 8, 7], dtype=int32)]
|
Here is a PR with a draft implementation for NumPy: numpy/numpy#26666, aligned with the
|
The most sane NaN handling, IMO, is to always sort NaNs to the end, which however means that if you implement The annoyance with that is that it means sort behavior diverges for asc/desc sort beyond a Of course one can just leave it unspecified here. OTOH, I dunno how much that limits the usability.
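One way to read the point about ascending and descending sorts diverging is sketched below in plain Python. The helper `sort_nan_last` is hypothetical and written only for this illustration; it is not taken from any array library. It assumes the "NaNs always last" policy and compares a true descending sort against simply reversing an ascending result.

```python
import math


def sort_nan_last(values, descending=False):
    """Sort floats so that NaNs always end up last (illustrative helper)."""

    def key(v):
        if math.isnan(v):
            # NaNs form their own group that sorts after every real number.
            return (1, 0.0)
        # Negating the value reverses the order of the non-NaN group only.
        return (0, -v if descending else v)

    return sorted(values, key=key)


data = [2.0, float("nan"), -1.0, 5.0]

asc = sort_nan_last(data)                    # NaN is the last element
desc = sort_nan_last(data, descending=True)  # NaN is still the last element
flipped = asc[::-1]                          # naive reversal puts NaN first
```

Under this policy a descending sort cannot be obtained by flipping the ascending result, so an implementation of "k largest" built on a flipped sort would surface NaNs first unless it handles them separately.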
The purpose of this PR is to continue several threads of discussion regarding `top_k`.

This roughly follows the specification of `top_k` in data-apis/array-api#722, with slight modifications to the API:

```py
def topk(
    x: array,
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    ...
```

Modifications:

- `mode: Literal["largest", "smallest"]` is replaced with `largest: bool`.
- `axis` is no longer a kw-only arg. This makes `torch.topk` slightly more compatible.

The tests implemented here follow the proposed `top_k` implementation at numpy/numpy#26666.
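As a rough illustration of the return contract (a tuple of values and indices, toggled by `largest`), here is a minimal pure-Python sketch for one-dimensional input. The name `topk_1d` is hypothetical, the `axis` argument is omitted, and two behaviors are assumptions made only for this sketch: results come back sorted, and `k` larger than the input raises `ValueError`. The proposal itself leaves both points open.

```python
import heapq
from typing import List, Sequence, Tuple


def topk_1d(
    x: Sequence[float], k: int, *, largest: bool = True
) -> Tuple[List[float], List[int]]:
    """Return (values, indices) of the k largest or smallest items of x."""
    if not 0 <= k <= len(x):
        # Assumption for this sketch only: out-of-range k is an error.
        raise ValueError(f"k must be in [0, {len(x)}], got {k}")
    pick = heapq.nlargest if largest else heapq.nsmallest
    # Pair each value with its position so the indices survive selection.
    pairs = pick(k, enumerate(x), key=lambda pair: pair[1])
    values = [value for _, value in pairs]
    indices = [index for index, _ in pairs]
    return values, indices


values, indices = topk_1d(list(range(100, 110)), 3)
small_values, small_indices = topk_1d([3.0, 1.0, 2.0], 2, largest=False)
```

With the input `100..109` and `k=3`, this yields the values `109, 108, 107` at indices `9, 8, 7`, the same pairing as the `jax.lax.top_k` session quoted earlier in the thread.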
Co-authored-by: ndgrigorian <[email protected]>
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
What is the value of using a string literal for this toggle? Are we anticipating other options? Why are string literals better than using an enum?
I don't think we're expecting other options. String literals are simply a standard way of doing this when only certain string values are valid.

> Why are string literals better than using an enum?
Enums are awful for defining a public API. You'll need to make the enums themselves public so the users of your API can use them, meaning you will increase the size of your API for every argument with a fixed set of options you add. Keeping the API surface small and easy to understand is an important goal of this standard - it's mostly functions and a few constants and other objects.
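To make the string-literal approach concrete, the sketch below shows how a `Literal`-typed argument might be annotated and validated at runtime without exporting any extra public object. The names `Mode` and `check_mode` are hypothetical and exist only for this illustration; the specification does not prescribe how implementations validate the argument.

```python
from typing import Literal, get_args

# The set of accepted strings lives in the annotation itself; callers pass
# plain strings and never need to import an enum class.
Mode = Literal["largest", "smallest"]


def check_mode(mode: Mode = "largest") -> bool:
    """Validate mode and report whether the largest elements are requested."""
    allowed = get_args(Mode)  # ("largest", "smallest")
    if mode not in allowed:
        raise ValueError(f"mode must be one of {allowed}, got {mode!r}")
    return mode == "largest"
```

A static type checker can flag `check_mode("biggest")` from the annotation alone, while the runtime check covers untyped callers.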
    """


    def top_k(
Generally this looks good. Other implementations I've seen pass a DeviceContext as well, but I'm not sure if we want to tackle that as part of the initial implementation.
Update: This PR resolves RFC: add topk and / or argpartition #629 by adding one new API to the Array API specification:

- `top_k`: returns a tuple whose first element is an array containing the top `k` largest (or smallest) values and whose second element is an array containing the indices of those `k` values.

The design decisions largely follow the discussion below. The specialized methods `top_k_indices` and `top_k_values` have been dropped from the specification.

This PR resolves RFC: add topk and / or argpartition #629 by adding 3 new APIs to the Array API specification:

- `top_k`: returns a tuple whose first element is an array containing the top `k` largest (or smallest) values and whose second element is an array containing the indices of those `k` values.
- `top_k_indices`: returns an array containing the indices of the `k` largest (or smallest) values.
- `top_k_values`: returns an array containing the `k` largest (or smallest) values.

## Prior Art
As illustrated in the API comparison, there is currently no consistent API across array libraries for returning the `k` largest or smallest values.

- `partition` and `argpartition`, but these return full arrays. When `axis` is `None`, NumPy operates on a flattened input array. To get the top `k` values, one must index and, if wanting sorted values, sort.
- `partition` and `argpartition` and follows NumPy; however, for `argpartition`, for implementation reasons, it performs a full sort.
- `topk` and switches "modes" (largest or smallest) based on whether `k` is positive or negative.
- `top_k` which ~~only returns values~~ always returns both values and indices, as well as NumPy equivalent `partition` and `argpartition` APIs (however, JAX differs in how it handles NaNs). The function only supports searching along the last axis.
- `topk` which always returns both values and indices.
- `top_k` which always returns both values and indices and only supports searching along the last axis.

## Proposed APIs
This PR attempts to synthesize the common themes and best ideas for "top k" APIs as observed among array libraries and attempts to define APIs which adhere to specification precedent in order to promote consistent design and reduce cognitive load.

### top_k

Returns a tuple containing the `k` largest (or smallest) elements in `x`.

Returns an array containing the indices of the `k` largest (or smallest) elements in `x`.

Returns an array containing the `k` largest (or smallest) elements in `x`.

## Design Decision Rationale
- `axis` is `None` in order to match `min`, `max`, `argmin`, and `argmax`. In those APIs, when `axis` is `None` (the default), the functions operate over a flattened array. Given that `top_k*` may be considered a generalization of the mentioned APIs, ensuring consistency seemed preferable to requiring users to remember a separate set of rules for `top_k*`.
- `axis` only supports `int` and `None` in order to match `argmin` and `argmax`. In `min` and `max`, the specification supports specifying multiple axes. Support for multiple axes can be a future specification extension.
- `top_k` was chosen over `topk` due to naming concerns discussed elsewhere (namely `top k` vs `to pk`). Furthermore, "top k" follows ML conventions, as opposed to `maxk`/`max_k` or `nlargest`/`nsmallest` as found in other languages.
- `unique`. In that case, rather than support polymorphic return values (e.g., returning values, returning values and indices, returning values and counts, etc.), we chose to define specific APIs which are monomorphic in their output. We innovated there, and the thinking that went into those design decisions seemed applicable here, where a user may want only values, indices, or both.
- `unique_*` naming convention, rather than the `arg*` naming convention, as there are three different return value situations: values, indices, and indices and values. Hence, using a suffix to describe what is returned, as in `unique_*`, seems reasonable and follows existing precedent in the specification.
- `largest` keyword argument. This PR chooses to name the kwarg `mode` in order to be more explicit (what does `largest=False` mean to the lay reader?) and follows precedent elsewhere in the specification (e.g., `linalg.qr`) where `mode` is used to toggle between different operating modes.
- `sorted` kwarg in order to instruct the API to return sorted values (or indices corresponding to sorted values) because (a) the kwarg is not universally supported currently, (b) downstream users can, at least for values, explicitly call `sort` (except in Dask, which doesn't currently support full sorting) after calling `top_k` or `top_k_values`, and (c) can be addressed in a future specification extension. Additionally, if we support `sorted`, we may also want to support a `stable` kwarg as in `sort` to allow ensuring that returned indices are consistent when provided the same input array.
- `k` exceeds the number of elements, as different behaviors seem acceptable (e.g., raising an exception or returning `m < k` values).

## Questions
- `k` exceeds the number of elements?
- `argmin` and `argmax`, the specification requires returning the index of the first occurrence when a minimum/maximum value occurs multiple times. Given that `top_k*` can be implemented as a partial sort, presumably we do not want to specify a first-occurrence restriction. Is this a reasonable assumption?
- `min` and `max`?
- `None` being the default for `axis`, where the default behavior is searching over a flattened array?

## Considerations
The APIs included in this PR have implications for the following array libraries:

- `unique_*`, will need to be added to the main namespace.
- `topk` and `argtopk`.
- `lax` namespace, this PR would introduce breaking changes, as ~~JAX would need to return both values and indices, by default, and~~ JAX would need to flatten by default rather than search along the last dimension. However, if implemented in its `numpy` namespace, these will simply be new APIs. In both scenarios, JAX will need to add support for `axis` and `mode` behavior.
- `topk`.
- `top_k_values` and `top_k_indices`). If implemented in its `math` namespace, this PR would introduce breaking changes as TensorFlow would need to flatten by default. However, if implemented in its `numpy` namespace, these will simply be new APIs. In both scenarios, TensorFlow will need to add support for `axis` and `mode` behavior.

## Related Links